{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_gan/python/estimator/tpu_gan_estimator.py:42: The name tf.estimator.tpu.TPUEstimator is deprecated. Please use tf.compat.v1.estimator.tpu.TPUEstimator instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_gan/python/estimator/tpu_gan_estimator.py:42: The name tf.estimator.tpu.TPUEstimator is deprecated. Please use tf.compat.v1.estimator.tpu.TPUEstimator instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from tensor2tensor import models\n",
    "from tensor2tensor import problems\n",
    "from tensor2tensor.layers import common_layers\n",
    "from tensor2tensor.utils import trainer_lib\n",
    "from tensor2tensor.utils import t2t_model\n",
    "from tensor2tensor.utils import registry\n",
    "from tensor2tensor.utils import metrics\n",
    "from tensor2tensor.data_generators import problem\n",
    "from tensor2tensor.data_generators import text_problems\n",
    "from tensor2tensor.data_generators import translate\n",
    "from tensor2tensor.utils import registry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sentencepiece as spm\n",
    "\n",
    "vocab = '../sp10m.cased.ms-en.model'\n",
    "sp = spm.SentencePieceProcessor()\n",
    "sp.Load(vocab)\n",
    "\n",
    "class Encoder:\n",
    "    def __init__(self, sp):\n",
    "        self.sp = sp\n",
    "        self.vocab_size = sp.GetPieceSize() + 100\n",
    "    \n",
    "    def encode(self, s):\n",
    "        return self.sp.EncodeAsIds(s)\n",
    "    \n",
    "    def decode(self, ids, strip_extraneous=False):\n",
    "        return self.sp.DecodeIds(list(ids))\n",
    "    \n",
    "encoder = Encoder(sp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from glob import glob\n",
    "\n",
    "@registry.register_problem\n",
    "class Seq2Seq(text_problems.Text2TextProblem):\n",
    "\n",
    "    @property\n",
    "    def approx_vocab_size(self):\n",
    "        return 32100\n",
    "    \n",
    "    @property\n",
    "    def is_generate_per_split(self):\n",
    "        return False\n",
    "            \n",
    "    def feature_encoders(self, data_dir):\n",
    "        encoder = Encoder(sp)\n",
    "        return {\n",
    "            \"inputs\": encoder,\n",
    "            \"targets\": encoder\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = os.path.expanduser('t2t-knowledge-graph/data')\n",
    "TMP_DIR = os.path.expanduser('t2t-knowledge-graph/tmp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROBLEM = 'seq2_seq'\n",
    "t2t_problem = problems.problem(PROBLEM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'t2t-knowledge-graph/train-base/model.ckpt-325000'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import os\n",
    "\n",
    "ckpt_path = tf.train.latest_checkpoint('t2t-knowledge-graph/train-base')\n",
    "ckpt_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor import models\n",
    "from tensor2tensor import problems\n",
    "from tensor2tensor.layers import common_layers\n",
    "from tensor2tensor.utils import trainer_lib\n",
    "from tensor2tensor.utils import t2t_model\n",
    "from tensor2tensor.utils import registry\n",
    "from tensor2tensor.utils import metrics\n",
    "from tensor2tensor.data_generators import problem\n",
    "from tensor2tensor.data_generators import text_problems\n",
    "from tensor2tensor.data_generators import translate\n",
    "from tensor2tensor.utils import registry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting T2TModel mode to 'infer'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting T2TModel mode to 'infer'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.dropout to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.dropout to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.label_smoothing to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.label_smoothing to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.layer_prepostprocess_dropout to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.layer_prepostprocess_dropout to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.symbol_dropout to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.symbol_dropout to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.attention_dropout to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.attention_dropout to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.relu_dropout to 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting hparams.relu_dropout to 0.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function framework at 0x7fcec41c0d08> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: LIVE_VARS_IN\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function framework at 0x7fcec41c0d08> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: LIVE_VARS_IN\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: Entity <function framework at 0x7fcec41c0d08> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: LIVE_VARS_IN\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/utils/t2t_model.py:2253: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/utils/t2t_model.py:2253: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/utils/t2t_model.py:1374: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/utils/t2t_model.py:1374: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_32100_512.bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_32100_512.bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "If using Keras pass *_constraint arguments to layers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_32100_512.targets_bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_32100_512.targets_bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/models/transformer.py:95: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/models/transformer.py:95: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function attention_bias_to_padding at 0x7fcebdeef8c8> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function attention_bias_to_padding at 0x7fcebdeef8c8> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: Entity <function attention_bias_to_padding at 0x7fcebdeef8c8> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: \n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/utils/expert_utils.py:621: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/utils/expert_utils.py:621: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function layers at 0x7fcf17b2e048> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: If not matching a CFG node, must be a block statement: <gast.gast.ImportFrom object at 0x7fceb5c9e828>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function layers at 0x7fcf17b2e048> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: If not matching a CFG node, must be a block statement: <gast.gast.ImportFrom object at 0x7fceb5c9e828>\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: Entity <function layers at 0x7fcf17b2e048> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: If not matching a CFG node, must be a block statement: <gast.gast.ImportFrom object at 0x7fceb5c9e828>\n",
      "INFO:tensorflow:Transforming body output with symbol_modality_32100_512.top\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_32100_512.top\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/models/transformer.py:1258: to_int64 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensor2tensor-1.15.7-py3.6.egg/tensor2tensor/models/transformer.py:1258: to_int64 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_32100_512.bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_32100_512.bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_32100_512.targets_bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_32100_512.targets_bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_32100_512.top\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_32100_512.top\n"
     ]
    }
   ],
   "source": [
    "class Model:\n",
    "    def __init__(self, HPARAMS = \"transformer_base\", DATA_DIR = 't2t/data'):\n",
    "        \n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        \n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype=tf.int32)\n",
    "        self.maxlen_decode = tf.reduce_max(self.X_seq_len)\n",
    "        #self.maxlen_decode = tf.placeholder(tf.int32, None)\n",
    "        \n",
    "        x = tf.expand_dims(tf.expand_dims(self.X, -1), -1)\n",
    "        y = tf.expand_dims(tf.expand_dims(self.Y, -1), -1)\n",
    "        \n",
    "        features = {\n",
    "            \"inputs\": x,\n",
    "            \"targets\": y,\n",
    "            \"target_space_id\": tf.constant(1, dtype=tf.int32),\n",
    "        }\n",
    "        self.features = features\n",
    "        \n",
    "        Modes = tf.estimator.ModeKeys\n",
    "        hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)\n",
    "        \n",
    "        translate_model = registry.model('transformer')(hparams, Modes.PREDICT)\n",
    "        self.translate_model = translate_model\n",
    "        logits, _ = translate_model(features)\n",
    "        self.logits = logits\n",
    "        \n",
    "        with tf.variable_scope(tf.get_variable_scope(), reuse=True):\n",
    "            self.fast_result = translate_model._greedy_infer(features, self.maxlen_decode)[\"outputs\"]\n",
    "            self.beam_result = translate_model._beam_decode_slow(\n",
    "                features, self.maxlen_decode, beam_size=3, \n",
    "                top_beams=1, alpha=0.5)[\"outputs\"]\n",
    "        \n",
    "        self.fast_result = tf.identity(self.fast_result, name = 'greedy')\n",
    "        self.beam_result = tf.identity(self.beam_result, name = 'beam')\n",
    "        \n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from t2t-knowledge-graph/train-base/model.ckpt-200000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from t2t-knowledge-graph/train-base/model.ckpt-200000\n"
     ]
    }
   ],
   "source": [
    "var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)\n",
    "saver = tf.train.Saver(var_list = var_lists)\n",
    "saver.restore(sess, 't2t-knowledge-graph/train-base/model.ckpt-200000')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from unidecode import unidecode\n",
    "\n",
    "def cleaning(string):\n",
    "    return re.sub(r'[ ]+', ' ', unidecode(string)).strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# encoded = encoder.encode(cleaning(\"Ketua negara Malaysia ialah Yang di Pertuan Agong iaitu raja elektif yang terpilih dan diundi dari kalangan sembilan raja negeri Melayu Ketua kerajaannya pula ialah Perdana Menteri.\")) + [1]\n",
    "# f = sess.run([model.nucleus_result], feed_dict = {model.X: [encoded], model.top_p: 0.9})\n",
    "# encoder.decode(f[0][0].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Honored Dato Sri Haji Mohammad Najib bin Tun Haji Abdul Razak occupation Politician, educated at Pekan Pahang, position held Member of the 2018 Parliament of Malaysia, member of 6th Prime Minister of Malaysia, position held Member of the 2018 Parliament of Malaysia, honorific prefix The Imam, position held Member of the 2018 Parliament of Malaysia.'"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoded = encoder.encode(cleaning(\"Yang Berhormat Dato Sri Haji Mohammad Najib bin Tun Haji Abdul Razak ialah ahli politik Malaysia dan merupakan bekas Perdana Menteri Malaysia ke-6 yang mana beliau menjawat jawatan dari 3 April 2009 hingga 9 Mei 2018. Beliau juga pernah berkhidmat sebagai bekas Menteri Kewangan dan merupakan Ahli Parlimen Pekan Pahang\")) + [1]\n",
    "f = sess.run([model.fast_result], feed_dict = {model.X: [encoded]})\n",
    "encoder.decode(f[0][0].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Pahang , Malaysia located in or next to body of water Pahang, shares border with Kelantan, located in or next to body of water Pahang, shares border with Perak Selangor, located in or next to body of water North, shares border with Negeri Sembilan, located in or next to body of water Terengganu.'"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoded = encoder.encode(cleaning(\"Pahang ialah negeri yang ketiga terbesar di Malaysia Terletak di lembangan Sungai Pahang yang amat luas negeri Pahang bersempadan dengan Kelantan di utara Perak Selangor serta Negeri Sembilan di barat Johor di selatan dan Terengganu dan Laut China Selatan di timur.\")) + [1]\n",
    "f = sess.run([model.fast_result], feed_dict = {model.X: [encoded]})\n",
    "encoder.decode(f[0][0].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "files = tf.gfile.Glob('../t2t-ms-en/data/*')\n",
    "len(files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# batch_size = 1\n",
    "# data_fields = {\n",
    "#     'inputs': tf.VarLenFeature(tf.int64),\n",
    "#     'targets': tf.VarLenFeature(tf.int64),\n",
    "# }\n",
    "# data_len = {\n",
    "#     'inputs': 300,\n",
    "#     'targets': 300,\n",
    "# }\n",
    "\n",
    "# def parse(serialized_example):\n",
    "\n",
    "#     features = tf.parse_single_example(\n",
    "#         serialized_example, features = data_fields\n",
    "#     )\n",
    "#     for k in features.keys():\n",
    "#         features[k] = features[k].values\n",
    "#         features[k] = tf.pad(\n",
    "#             features[k], [[0, data_len[k] - tf.shape(features[k])[0]]]\n",
    "#         )\n",
    "#         features[k].set_shape((data_len[k]))\n",
    "\n",
    "#     return features\n",
    "\n",
    "# def _decode_record(example, name_to_features):\n",
    "#     \"\"\"Decodes a record to a TensorFlow example.\"\"\"\n",
    "\n",
    "#     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.\n",
    "#     # So cast all int64 to int32.\n",
    "#     for name in list(example.keys()):\n",
    "#         t = example[name]\n",
    "#         if t.dtype == tf.int64:\n",
    "#             t = tf.to_int32(t)\n",
    "#         example[name] = t\n",
    "\n",
    "#     return example\n",
    "\n",
    "# d = tf.data.TFRecordDataset(files)\n",
    "# d = d.map(parse, num_parallel_calls = 32)\n",
    "# d = d.apply(\n",
    "#     tf.contrib.data.map_and_batch(\n",
    "#         lambda record: _decode_record(record, data_fields),\n",
    "#         batch_size = batch_size,\n",
    "#         num_parallel_batches = 4,\n",
    "#         drop_remainder = True,\n",
    "#     )\n",
    "# )\n",
    "# d = d.make_one_shot_iterator().get_next()\n",
    "# d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# r_ = sess.run(d)\n",
    "# encoder.decode(r_['inputs'][0].tolist()), encoder.decode(r_['targets'][0].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# encoder.decode(r_['inputs'][2].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# encoder.decode(r_['targets'][2].tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'transformer-base/model.ckpt'"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.save(sess, 'transformer-base/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "strings = ','.join(\n",
    "    [\n",
    "        n.name\n",
    "        for n in tf.get_default_graph().as_graph_def().node\n",
    "        if ('Variable' in n.op\n",
    "        or 'Placeholder' in n.name\n",
    "        or 'greedy' in n.name\n",
    "        or 'beam' in n.name\n",
    "        or 'alphas' in n.name\n",
    "        or 'self/Softmax' in n.name)\n",
    "        and 'adam' not in n.name\n",
    "        and 'beta' not in n.name\n",
    "        and 'global_step' not in n.name\n",
    "        and 'modality' not in n.name\n",
    "        and 'Assign' not in n.name\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def freeze_graph(model_dir, output_node_names):\n",
    "\n",
    "    if not tf.gfile.Exists(model_dir):\n",
    "        raise AssertionError(\n",
    "            \"Export directory doesn't exists. Please specify an export \"\n",
    "            'directory: %s' % model_dir\n",
    "        )\n",
    "\n",
    "    checkpoint = tf.train.get_checkpoint_state(model_dir)\n",
    "    input_checkpoint = checkpoint.model_checkpoint_path\n",
    "\n",
    "    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])\n",
    "    output_graph = absolute_model_dir + '/frozen_model.pb'\n",
    "    clear_devices = True\n",
    "    with tf.Session(graph = tf.Graph()) as sess:\n",
    "        saver = tf.train.import_meta_graph(\n",
    "            input_checkpoint + '.meta', clear_devices = clear_devices\n",
    "        )\n",
    "        saver.restore(sess, input_checkpoint)\n",
    "        output_graph_def = tf.graph_util.convert_variables_to_constants(\n",
    "            sess,\n",
    "            tf.get_default_graph().as_graph_def(),\n",
    "            output_node_names.split(','),\n",
    "        )\n",
    "        with tf.gfile.GFile(output_graph, 'wb') as f:\n",
    "            f.write(output_graph_def.SerializeToString())\n",
    "        print('%d ops in the final graph.' % len(output_graph_def.node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from transformer-base/model.ckpt\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from transformer-base/model.ckpt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Froze 201 variables.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Froze 201 variables.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Converted 201 variables to const ops.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Converted 201 variables to const ops.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11004 ops in the final graph.\n"
     ]
    }
   ],
   "source": [
    "freeze_graph('transformer-base', strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.tools.graph_transforms import TransformGraph\n",
    "from glob import glob\n",
    "tf.set_random_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow_text\n",
    "import tf_sentencepiece"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "transformer-base/frozen_model.pb\n"
     ]
    }
   ],
   "source": [
    "transforms = ['add_default_attributes',\n",
    "             'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',\n",
    "             'fold_batch_norms',\n",
    "             'fold_old_batch_norms',\n",
    "             'quantize_weights(fallback_min=-10, fallback_max=10)',\n",
    "             'strip_unused_nodes',\n",
    "             'sort_by_execution_order']\n",
    "\n",
    "pb = 'transformer-base/frozen_model.pb'\n",
    "input_graph_def = tf.GraphDef()\n",
    "with tf.gfile.FastGFile(pb, 'rb') as f:\n",
    "    input_graph_def.ParseFromString(f.read())\n",
    "\n",
    "print(pb)\n",
    "\n",
    "transformed_graph_def = TransformGraph(input_graph_def, \n",
    "                                       ['Placeholder'],\n",
    "                                       ['greedy', 'beam'], transforms)\n",
    "\n",
    "with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:\n",
    "    f.write(transformed_graph_def.SerializeToString())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
